# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content")
load("//build_tools/bazel:iree_bitcode_library.bzl", "iree_bitcode_library", "iree_link_bitcode")

# Package-wide defaults for every target declared in this file:
#   - targets are visible to all other packages,
#   - the "layering_check" feature is enabled,
#   - the license type is declared as "notice".
package(
    default_visibility = ["//visibility:public"],
    features = ["layering_check"],
    licenses = ["notice"],  # Apache 2.0
)

#===------------------------------------------------------------------------===#
# UKernel bitcode files
#===------------------------------------------------------------------------===#

# Raw CMake text that opens an `if(_IREE_UKERNEL_BITCODE_BUILD_X86_64)` block.
# The block is intentionally left open here: the matching `elseif()`/`endif()`
# is supplied by the second iree_cmake_extra_content() call at the end of this
# file, so the rules declared in between land inside the conditional.
# NOTE(review): presumably iree_compiler_targeting_iree_arch() sets the variable
# when the compiler is configured to target x86_64, and `inline = True` places
# the text at this exact position in the generated CMake — confirm against
# build_defs.oss.bzl and the CMake helper's definition.
iree_cmake_extra_content(
    content = """
iree_compiler_targeting_iree_arch(_IREE_UKERNEL_BITCODE_BUILD_X86_64 "x86_64")
if(_IREE_UKERNEL_BITCODE_BUILD_X86_64)
""",
    inline = True,
)

# All headers transitively included by code in this directory. Bazel-only.
#
# The list is shared: it is passed as `internal_hdrs` to every
# iree_bitcode_library() in this file. The first entries are header/.inl files
# local to this package; the last two are filegroup labels from other packages
# (the generic ukernel internal headers and the CPU data schema headers).
# When a source file in this directory starts including a new header, add it
# here as well.
UKERNEL_X86_64_INTERNAL_HEADERS = [
    "common_x86_64.h",
    "mmt4d_x86_64_internal.h",
    "mmt4d_x86_64_tiles.inl",
    "pack_x86_64_internal.h",
    "unpack_x86_64_internal.h",
    "//runtime/src/iree/builtins/ukernel:internal_headers_filegroup",
    "//runtime/src/iree/schemas:cpu_data_headers_filegroup",
]

# Bitcode library built from the x86_64 entry-point sources for the mmt4d,
# pack and unpack ukernels.
# Unlike the ISA-specific libraries further down, no `copts` are passed here,
# so these sources are compiled without any extra -m<feature> flags.
# NOTE(review): presumably this keeps the entry points (which select an
# implementation) free of instructions beyond the x86_64 baseline — confirm.
iree_bitcode_library(
    name = "ukernel_bitcode_arch_x86_64_entry_points",
    srcs = [
        "mmt4d_x86_64_entry_point.c",
        "pack_x86_64_entry_point.c",
        "unpack_x86_64_entry_point.c",
    ],
    arch = "x86_64",
    internal_hdrs = UKERNEL_X86_64_INTERNAL_HEADERS,
)

# Compiler flags enabling the AVX, AVX2, FMA and F16C instruction set
# extensions. Used directly by the avx2_fma library and as the base that the
# AVX-512 flag lists below extend.
UKERNEL_X86_64_AVX2_FMA_COPTS = [
    "-mavx",
    "-mavx2",
    "-mfma",
    "-mf16c",
]

# Bitcode library with the AVX2+FMA implementations of the mmt4d, pack and
# unpack ukernels, compiled with UKERNEL_X86_64_AVX2_FMA_COPTS.
iree_bitcode_library(
    name = "ukernel_bitcode_arch_x86_64_avx2_fma",
    srcs = [
        "mmt4d_x86_64_avx2_fma.c",
        "pack_x86_64_avx2_fma.c",
        "unpack_x86_64_avx2_fma.c",
    ],
    arch = "x86_64",
    copts = UKERNEL_X86_64_AVX2_FMA_COPTS,
    internal_hdrs = UKERNEL_X86_64_INTERNAL_HEADERS,
)

# Compiler flags for the "AVX-512 base" feature set: everything in
# UKERNEL_X86_64_AVX2_FMA_COPTS plus the AVX-512 F, VL, CD, BW and DQ
# extensions. The VNNI and BF16 flag lists below are built on top of this one.
UKERNEL_X86_64_AVX512_BASE_COPTS = UKERNEL_X86_64_AVX2_FMA_COPTS + [
    "-mavx512f",
    "-mavx512vl",
    "-mavx512cd",
    "-mavx512bw",
    "-mavx512dq",
]

# Bitcode library with the AVX-512 base implementations of the mmt4d, pack and
# unpack ukernels, compiled with UKERNEL_X86_64_AVX512_BASE_COPTS.
iree_bitcode_library(
    name = "ukernel_bitcode_arch_x86_64_avx512_base",
    srcs = [
        "mmt4d_x86_64_avx512_base.c",
        "pack_x86_64_avx512_base.c",
        "unpack_x86_64_avx512_base.c",
    ],
    arch = "x86_64",
    copts = UKERNEL_X86_64_AVX512_BASE_COPTS,
    internal_hdrs = UKERNEL_X86_64_INTERNAL_HEADERS,
)

# Compiler flags for the AVX-512 VNNI variant: the AVX-512 base flags plus
# -mavx512vnni.
UKERNEL_X86_64_AVX512_VNNI_COPTS = UKERNEL_X86_64_AVX512_BASE_COPTS + [
    "-mavx512vnni",
]

# Bitcode library with the AVX-512 VNNI implementation of the mmt4d ukernel,
# compiled with UKERNEL_X86_64_AVX512_VNNI_COPTS. Only mmt4d has a source file
# for this variant; there are no VNNI-specific pack/unpack sources listed.
iree_bitcode_library(
    name = "ukernel_bitcode_arch_x86_64_avx512_vnni",
    srcs = [
        "mmt4d_x86_64_avx512_vnni.c",
    ],
    arch = "x86_64",
    copts = UKERNEL_X86_64_AVX512_VNNI_COPTS,
    internal_hdrs = UKERNEL_X86_64_INTERNAL_HEADERS,
)

# Compiler flags for the AVX-512 BF16 variant: the AVX-512 base flags plus
# -mavx512bf16. Note this extends the base list, not the VNNI list.
UKERNEL_X86_64_AVX512_BF16_COPTS = UKERNEL_X86_64_AVX512_BASE_COPTS + [
    "-mavx512bf16",
]

# Bitcode library with the AVX-512 BF16 implementation of the mmt4d ukernel,
# compiled with UKERNEL_X86_64_AVX512_BF16_COPTS. Only mmt4d has a source file
# for this variant; there are no BF16-specific pack/unpack sources listed.
iree_bitcode_library(
    name = "ukernel_bitcode_arch_x86_64_avx512_bf16",
    srcs = [
        "mmt4d_x86_64_avx512_bf16.c",
    ],
    arch = "x86_64",
    copts = UKERNEL_X86_64_AVX512_BF16_COPTS,
    internal_hdrs = UKERNEL_X86_64_INTERNAL_HEADERS,
)

# Links the per-variant bitcode files into the single
# "ukernel_bitcode_arch_x86_64" target. Each entry is the name of one of the
# iree_bitcode_library() targets above with a ".bc" suffix, so a new variant
# library must also be added to this list to be included.
# NOTE(review): presumably each iree_bitcode_library() emits "<name>.bc" —
# confirm in iree_bitcode_library.bzl.
iree_link_bitcode(
    name = "ukernel_bitcode_arch_x86_64",
    bitcode_files = [
        "ukernel_bitcode_arch_x86_64_entry_points.bc",
        "ukernel_bitcode_arch_x86_64_avx2_fma.bc",
        "ukernel_bitcode_arch_x86_64_avx512_base.bc",
        "ukernel_bitcode_arch_x86_64_avx512_vnni.bc",
        "ukernel_bitcode_arch_x86_64_avx512_bf16.bc",
    ],
)

# Raw CMake text that closes the `if(_IREE_UKERNEL_BITCODE_BUILD_X86_64)` block
# opened by the first iree_cmake_extra_content() call in this file.
# The `elseif` branch applies when that variable is false but both
# IREE_BUILD_COMPILER and IREE_TARGET_BACKEND_LLVM_CPU are true: it calls
# iree_make_empty_file() for ukernel_bitcode_arch_x86_64.bc in the current
# binary directory instead of building the real bitcode.
# NOTE(review): presumably the empty file is a placeholder so that consumers
# of the .bc path still find a file when x86_64 is not targeted — confirm.
iree_cmake_extra_content(
    content = """
elseif(IREE_BUILD_COMPILER AND IREE_TARGET_BACKEND_LLVM_CPU)
iree_make_empty_file("${CMAKE_CURRENT_BINARY_DIR}/ukernel_bitcode_arch_x86_64.bc")
endif()  # _IREE_UKERNEL_BITCODE_BUILD_X86_64
""",
    inline = True,
)
